/******************************************************************************
 * Copyright 2018 The Apollo Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *****************************************************************************/

/**
 * @file
  Modification: Only some functions are referenced
**/

#include "dp_st_graph.h"
namespace
{
    static constexpr double kDoubleEpsilon = 1.0e-6;

    // Continuous-time collision check: the motion between two ST-graph nodes is
    // modelled as the straight segment joining them, and that segment is tested
    // against every ST boundary.
    //
    // Returns true as soon as one boundary (other than a KEEP_CLEAR one)
    // overlaps the segment p1 -> p2. Always returns false when
    // Config_.FLAGS_use_st_drivable_boundary is set, because the check is
    // bypassed in that mode.
    bool CheckOverlapOnDpStGraph(const std::vector<ST_Boundary> &boundaries, const StGraphPoint &p1, const StGraphPoint &p2)
    {
        if (!Config_.FLAGS_use_st_drivable_boundary)
        {
            for (const auto &st_boundary : boundaries)
            {
                // KEEP_CLEAR boundaries never count as a collision here.
                const bool is_keep_clear =
                    st_boundary.boundary_type() == ST_Boundary::BoundaryType::KEEP_CLEAR;
                // Polygon-versus-line-segment overlap test.
                if (!is_keep_clear && st_boundary.HasOverlap({p1.point(), p2.point()}))
                {
                    return true;
                }
            }
        }
        return false;
    }
} // namespace

// Stores the planning inputs, builds the cost evaluator, and derives the grid
// resolution of the ST graph from the global configuration:
//   unit_s_ = total_path_length / (matrix_dimension_s - 1)
//   unit_t_ = total_time        / (matrix_dimension_t - 1)
// i.e. the s/t span is divided into (dimension - 1) equal intervals.
DpStGraph::DpStGraph(const StGraphData &st_graph_data,
                     const std::vector<const Obstacle *> &obstacles,
                     const TrajectoryPoint &init_point,
                     const SL_Boundary &adc_sl_boundary)
    : st_graph_data_(st_graph_data),
      obstacles_(obstacles),
      init_point_(init_point),
      dp_st_cost_(st_graph_data_.total_time_by_conf(), st_graph_data_.path_length(), obstacles, init_point_),
      adc_sl_boundary_(adc_sl_boundary)
{
    // Number of intervals along each axis is one less than the number of nodes.
    const auto s_interval_count = Config_.matrix_dimension_s - 1;
    const auto t_interval_count = Config_.matrix_dimension_t - 1;

    unit_s_ = Config_.total_path_length / s_interval_count;
    unit_t_ = Config_.total_time / t_interval_count;
}

// Runs the DP speed search and writes the resulting profile to *speed_data.
//
// Three outcomes are possible:
//   1. A non-KEEP_CLEAR boundary already covers (or starts at) the ST origin:
//      a stationary profile (s = v = a = 0) is written.
//   2. There are no ST boundaries at all: a constant-speed profile that
//      advances one grid cell per time step is written.
//   3. Otherwise the cost table is initialised, filled, and back-tracked.
//
// @param speed_data  Output profile; must not be null.
// @return true when a profile was written; false when speed_data is null or
//         any of the three DP stages fails.
bool DpStGraph::Search(SpeedData *const speed_data)
{
    // Every success path dereferences the out-parameter, so reject null early.
    if (speed_data == nullptr)
    {
        return false;
    }

    constexpr double kBoundaryEpsilon = 1e-2;
    for (const auto &boundary : st_graph_data_.st_boundaries())
    {
        if (boundary.boundary_type() == ST_Boundary::BoundaryType::KEEP_CLEAR)
        {
            continue;
        }

        const bool starts_at_origin = std::fabs(boundary.min_t_) < kBoundaryEpsilon &&
                                      std::fabs(boundary.min_s_) < kBoundaryEpsilon;
        if (boundary.IsPointInBoundary({0.0, 0.0}) || starts_at_origin)
        {
            // The origin is blocked: emit a stand-still profile.
            // Note the inclusive bound: matrix_dimension_t + 1 samples are emitted.
            std::vector<SpeedPoint> speed_profile;
            double t = 0.0;
            for (int i = 0; i <= Config_.matrix_dimension_t; ++i, t += unit_t_)
            {
                SpeedPoint speed_point;
                speed_point.set_s(0.0);
                speed_point.set_t(t);
                speed_point.set_v(0.0);
                speed_point.set_a(0.0);
                speed_profile.emplace_back(speed_point);
            }
            *speed_data = SpeedData(std::move(speed_profile));
            return true;
        }
    }

    if (st_graph_data_.st_boundaries().empty())
    {
        // No path obstacles: output the default speed profile, one grid cell of
        // distance per grid cell of time.
        const double v_default = unit_s_ / unit_t_; // loop-invariant
        std::vector<SpeedPoint> speed_profile;
        double s = 0.0;
        double t = 0.0;
        for (int i = 0; i <= Config_.matrix_dimension_t &&
                        i <= Config_.matrix_dimension_s;
             ++i, t += unit_t_, s += unit_s_)
        {
            SpeedPoint speed_point;
            speed_point.set_s(s);
            speed_point.set_t(t);
            speed_point.set_v(v_default);
            speed_point.set_a(0.0);
            speed_profile.emplace_back(std::move(speed_point));
        }
        *speed_data = SpeedData(std::move(speed_profile));
        return true;
    }

    // Failure reason: "Initialize cost table failed."
    if (!InitCostTable())
    {
        return false;
    }

    // Failure reason: "Calculate total cost failed."
    if (!CalculateTotalCost())
    {
        return false;
    }

    // Failure reason: "Retrieve best speed profile failed."
    if (!RetrieveSpeedProfile(speed_data))
    {
        return false;
    }
    return true;
}

// Allocates the cost table as matrix_dimension_t columns (time) of
// matrix_dimension_s rows (distance) and initialises every node with its
// (column, row) index and its ST coordinate.
//
// @return false when either dimension is smaller than 2, true otherwise.
bool DpStGraph::InitCostTable()
{
    const uint32_t rows_s = Config_.matrix_dimension_s;
    const uint32_t cols_t = Config_.matrix_dimension_t;
    if (rows_s < 2 || cols_t < 2)
    {
        return false;
    }

    cost_table_.assign(cols_t, std::vector<StGraphPoint>(rows_s, StGraphPoint()));

    // Coordinates are accumulated step by step along each axis.
    double t_value = 0.0;
    uint32_t col = 0;
    for (auto &column : cost_table_)
    {
        double s_value = 0.0;
        uint32_t row = 0;
        for (auto &node : column)
        {
            node.Init(col, row, STPoint(s_value, t_value));
            ++row;
            s_value += unit_s_;
        }
        ++col;
        t_value += unit_t_;
    }
    return true;
}

bool DpStGraph::CalculateTotalCost()
{
    // col and row are for STGraph
    // t corresponding to col
    // s corresponding to row
    size_t next_highest_row = 0;
    size_t next_lowest_row = 0;

    for (size_t c = 0; c < cost_table_.size(); ++c)
    {
        size_t highest_row = 0;
        auto lowest_row = cost_table_.back().size() - 1;

        int count = static_cast<int>(next_highest_row) -
                    static_cast<int>(next_lowest_row) + 1;
        if (count > 0)
        {
            for (size_t r = next_lowest_row; r <= next_highest_row; ++r)
            {
                auto msg = std::make_shared<StGraphMessage>(c, r);
                CalculateCostAt(msg);
            }
        }

        for (size_t r = next_lowest_row; r <= next_highest_row; ++r)
        {
            const auto &cost_cr = cost_table_[c][r];
            if (cost_cr.total_cost() < std::numeric_limits<double>::infinity())
            {
                size_t h_r = 0;
                size_t l_r = 0;
                GetRowRange(cost_cr, &h_r, &l_r);
                highest_row = std::max(highest_row, h_r);
                lowest_row = std::min(lowest_row, static_cast<size_t>(l_r));
            }
        }
        next_highest_row = highest_row;
        next_lowest_row = lowest_row;
    }

    return true;
}

void DpStGraph::GetRowRange(const StGraphPoint &point, size_t *next_highest_row,
                            size_t *next_lowest_row)
{
    double v0 = 0.0;
    if (!point.pre_point())
    {
        v0 = init_point_.v;
    }
    else
    {
        v0 = (point.index_s() - point.pre_point()->index_s()) * unit_s_ / unit_t_;
    }

    const auto max_s_size = cost_table_.back().size() - 1;

    const double speed_coeff = unit_t_ * unit_t_;

    const double delta_s_upper_bound =
        v0 * unit_t_ + Config_.max_acceleration * speed_coeff;
    *next_highest_row =
        point.index_s() + static_cast<int>(delta_s_upper_bound / unit_s_);
    if (*next_highest_row >= max_s_size)
    {
        *next_highest_row = max_s_size;
    }

    const double delta_s_lower_bound = std::fmax(
        0.0, v0 * unit_t_ + Config_.max_deceleration * speed_coeff);
    *next_lowest_row =
        point.index_s() + static_cast<int>(delta_s_lower_bound / unit_s_);
    if (*next_lowest_row > max_s_size)
    {
        *next_lowest_row = max_s_size;
    }
    else if (*next_lowest_row < 0)
    {
        *next_lowest_row = 0;
    }
}

// Evaluates one node of the cost table: the node at column msg->c (time index)
// and row msg->r (distance index). Sets the node's obstacle cost and, when a
// feasible predecessor in column c - 1 exists, its best total cost and its
// predecessor pointer. When no feasible predecessor is found the node's total
// cost and predecessor are left untouched.
//
// Columns 0, 1 and 2 are special-cased because they have fewer than three
// earlier nodes available for the acceleration / jerk terms.
void DpStGraph::CalculateCostAt(const std::shared_ptr<StGraphMessage> &msg)
{
    const uint32_t c = msg->c;
    const uint32_t r = msg->r;
    auto &cost_cr = cost_table_[c][r];
    cost_cr.SetObstacleCost(dp_st_cost_.GetObstacleCost(cost_cr));
    // Only +infinity compares greater than the largest finite double, so this
    // rejects nodes whose obstacle cost is infinite.
    if (cost_cr.obstacle_cost() > std::numeric_limits<double>::max())
    {
        return;
    }

    // The origin node (t index 0, s index 0) is the root of every path.
    const auto &cost_init = cost_table_[0][0];
    if (c == 0)
    {
        // DCHECK_EQ(r, 0) << "Incorrect. Row should be 0 with col = 0. row: " << r;
        cost_cr.SetTotalCost(0.0);
        return;
    }

    double speed_limit =
        st_graph_data_.speed_limit().GetSpeedLimitByS(unit_s_ * r);
    if (c == 1)
    {
        // The only possible predecessor is the origin. The required
        // acceleration is the change from the initial speed to the average
        // speed over the first step, divided by the step duration.
        const double acc = (r * unit_s_ / unit_t_ - init_point_.v) / unit_t_;
        if (acc < Config_.max_deceleration ||
            acc > Config_.max_acceleration)
        {
            return;
        }

        if (CheckOverlapOnDpStGraph(st_graph_data_.st_boundaries(), cost_cr,
                                    cost_init))
        {
            return;
        }
        cost_cr.SetTotalCost(cost_cr.obstacle_cost() + cost_init.total_cost() +
                             CalculateEdgeCostForSecondCol(r, speed_limit));
        cost_cr.SetPrePoint(cost_init);
        return;
    }

    // Largest number of rows that can be advanced in one time step when
    // travelling at the planning upper speed limit plus a 20% buffer.
    constexpr double kSpeedRangeBuffer = 0.20;
    const uint32_t max_s_diff =
        static_cast<uint32_t>(Config_.planning_upper_speed_limit *
                              (1 + kSpeedRangeBuffer) * unit_t_ / unit_s_);
    // Lowest predecessor row worth considering (saturating at row 0).
    const uint32_t r_low = (max_s_diff < r ? r - max_s_diff : 0);

    const auto &pre_col = cost_table_[c - 1];

    if (c == 2)
    {
        // Candidate predecessors are the rows r_low..r of column 1 (no backward
        // motion: the predecessor row never exceeds the current row).
        for (uint32_t r_pre = r_low; r_pre <= r; ++r_pre)
        {
            // Second-order finite difference over rows (0, r_pre, r); the first
            // sample is the origin, whose row index is 0.
            const double acc =
                (r * unit_s_ - 2 * r_pre * unit_s_) / (unit_t_ * unit_t_);
            if (acc < Config_.max_deceleration ||
                acc > Config_.max_acceleration)
            {
                continue;
            }

            if (CheckOverlapOnDpStGraph(st_graph_data_.st_boundaries(), cost_cr,
                                        pre_col[r_pre]))
            {
                continue;
            }

            const double cost = cost_cr.obstacle_cost() +
                                pre_col[r_pre].total_cost() +
                                CalculateEdgeCostForThirdCol(r, r_pre, speed_limit);

            // Keep the cheapest predecessor found so far.
            if (cost < cost_cr.total_cost())
            {
                cost_cr.SetTotalCost(cost);
                cost_cr.SetPrePoint(pre_col[r_pre]);
            }
        }
        return;
    }
    // General case (c >= 3): relax over every candidate predecessor in the
    // previous column.
    for (uint32_t r_pre = r_low; r_pre <= r; ++r_pre)
    {
        // Skip predecessors that were never reached.
        if (std::isinf(pre_col[r_pre].total_cost()) ||
            pre_col[r_pre].pre_point() == nullptr)
        {
            continue;
        }

        // Second-order finite difference over the rows of
        // (predecessor's predecessor, predecessor, current node).
        const double curr_a = (cost_cr.index_s() * unit_s_ +
                               pre_col[r_pre].pre_point()->index_s() * unit_s_ -
                               2 * pre_col[r_pre].index_s() * unit_s_) /
                              (unit_t_ * unit_t_);
        if (curr_a > Config_.max_acceleration ||
            curr_a < Config_.max_deceleration)
        {
            continue;
        }
        if (CheckOverlapOnDpStGraph(st_graph_data_.st_boundaries(), cost_cr,
                                    pre_col[r_pre]))
        {
            continue;
        }

        // Look up the predecessor's predecessor in column c - 2; the jerk term
        // below additionally needs that node's own predecessor.
        uint32_t r_prepre = pre_col[r_pre].pre_point()->index_s();
        const StGraphPoint &prepre_graph_point = cost_table_[c - 2][r_prepre];
        if (std::isinf(prepre_graph_point.total_cost()))
        {
            continue;
        }

        if (!prepre_graph_point.pre_point())
        {
            continue;
        }
        // The four consecutive ST points along this candidate path, oldest first.
        const STPoint &triple_pre_point = prepre_graph_point.pre_point()->point();
        const STPoint &prepre_point = prepre_graph_point.point();
        const STPoint &pre_point = pre_col[r_pre].point();
        const STPoint &curr_point = cost_cr.point();
        double cost = cost_cr.obstacle_cost() + pre_col[r_pre].total_cost() +
                      CalculateEdgeCost(triple_pre_point, prepre_point, pre_point,
                                        curr_point, speed_limit);

        // Keep the cheapest predecessor found so far.
        if (cost < cost_cr.total_cost())
        {
            cost_cr.SetTotalCost(cost);
            cost_cr.SetPrePoint(pre_col[r_pre]);
        }
    }
}

// Back-tracks the cheapest path through the filled cost table and writes it to
// *speed_data as a sequence of (s, t, v) samples.
//
// End-point candidates are every node of the last column (end of the time
// horizon) and the last node of every column (end of the path length); the one
// with the smallest finite total cost wins. The path is then followed through
// the pre_point() links back to its root.
//
// @param speed_data  Output profile; must not be null.
// @return false when speed_data is null, the cost table is empty, no end point
//         with a finite cost exists, or the back-tracked path does not start at
//         the ST origin; true otherwise.
bool DpStGraph::RetrieveSpeedProfile(SpeedData *const speed_data)
{
    // back() on an empty container is undefined behaviour, so guard first.
    if (speed_data == nullptr || cost_table_.empty() || cost_table_.back().empty())
    {
        return false;
    }

    double min_cost = std::numeric_limits<double>::infinity();
    const StGraphPoint *best_end_point = nullptr;
    // Adopts `candidate` as the best end point when its cost is finite and
    // strictly lower than the best seen so far.
    const auto consider = [&min_cost, &best_end_point](const StGraphPoint &candidate) {
        if (!std::isinf(candidate.total_cost()) &&
            candidate.total_cost() < min_cost)
        {
            best_end_point = &candidate;
            min_cost = candidate.total_cost();
        }
    };

    // Paths that last until the end of the time horizon.
    for (const StGraphPoint &cur_point : cost_table_.back())
    {
        consider(cur_point);
    }

    // Paths that reach the end of the path length at any time.
    for (const auto &row : cost_table_)
    {
        if (!row.empty())
        {
            consider(row.back());
        }
    }

    // Failure reason: "Fail to find the best feasible trajectory."
    if (best_end_point == nullptr)
    {
        return false;
    }

    // Walk the predecessor chain from the end point back to the root, then
    // reverse so the profile is ordered by increasing time.
    // NOTE(review): `a` is never set on these points and `v` is never set on
    // the final point; they keep whatever SpeedPoint's default construction
    // yields, which is not visible from this file - confirm that default is 0.
    std::vector<SpeedPoint> speed_profile;
    for (const StGraphPoint *cur_point = best_end_point; cur_point != nullptr;
         cur_point = cur_point->pre_point())
    {
        SpeedPoint speed_point;
        speed_point.set_s(cur_point->point().s());
        speed_point.set_t(cur_point->point().t());
        speed_profile.emplace_back(speed_point);
    }
    std::reverse(speed_profile.begin(), speed_profile.end());

    // A valid path must start at the ST origin (t = 0, s = 0).
    // Failure reason: "Fail to retrieve speed profile."
    constexpr double kEpsilon = std::numeric_limits<double>::epsilon();
    if (speed_profile.front().t > kEpsilon ||
        speed_profile.front().s > kEpsilon)
    {
        return false;
    }

    // Forward-difference speed; the 1e-3 term keeps the denominator non-zero.
    for (size_t i = 0; i + 1 < speed_profile.size(); ++i)
    {
        const double v = (speed_profile[i + 1].s - speed_profile[i].s) /
                         (speed_profile[i + 1].t - speed_profile[i].t + 1e-3);
        speed_profile[i].set_v(v);
    }

    *speed_data = SpeedData(std::move(speed_profile));
    return true;
}

// Cost of the edge third -> forth for a node that has at least three earlier
// points on its path. It is the sum of:
//   - the speed cost of the last segment (third -> forth),
//   - the acceleration cost over the last three points,
//   - the jerk cost over all four points.
double DpStGraph::CalculateEdgeCost(const STPoint &first, const STPoint &second,
                                    const STPoint &third, const STPoint &forth,
                                    const double speed_limit)
{
    const auto speed_term =
        dp_st_cost_.GetSpeedCost(third, forth, speed_limit, Config_.default_cruise_speed);
    const auto accel_term =
        dp_st_cost_.GetAccelCostByThreePoints(second, third, forth);
    const auto jerk_term =
        dp_st_cost_.GetJerkCostByFourPoints(first, second, third, forth);
    return speed_term + accel_term + jerk_term;
}

double DpStGraph::CalculateEdgeCostForSecondCol(const uint32_t row,
                                                const double speed_limit)
{
    double init_speed = init_point_.v;
    double init_acc = init_point_.a;
    const STPoint &pre_point = cost_table_[0][0].point();
    const STPoint &curr_point = cost_table_[1][row].point();
    return dp_st_cost_.GetSpeedCost(pre_point, curr_point, speed_limit, Config_.default_cruise_speed) +
           dp_st_cost_.GetAccelCostByTwoPoints(init_speed, pre_point,
                                               curr_point) +
           dp_st_cost_.GetJerkCostByTwoPoints(init_speed, init_acc, pre_point,
                                              curr_point);
}

double DpStGraph::CalculateEdgeCostForThirdCol(const uint32_t curr_row,
                                               const uint32_t pre_row,
                                               const double speed_limit)
{
    double init_speed = init_point_.v;
    const STPoint &first = cost_table_[0][0].point();
    const STPoint &second = cost_table_[1][pre_row].point();
    const STPoint &third = cost_table_[2][curr_row].point();
    return dp_st_cost_.GetSpeedCost(second, third, speed_limit, Config_.default_cruise_speed) +
           dp_st_cost_.GetAccelCostByThreePoints(first, second, third) +
           dp_st_cost_.GetJerkCostByThreePoints(init_speed, first, second, third);
}